import math
import os
from typing import TYPE_CHECKING, List, Optional

import huggingface_hub
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.memory_management.manager import MemoryManager
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.base_model import BaseModel
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze, QTensor
from toolkit.util.quantize import quantize, get_qtype, quantize_model

from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from .src.model import Flux2, Flux2Params
from .src.pipeline import Flux2Pipeline
from .src.autoencoder import AutoEncoder, AutoEncoderParams
from safetensors.torch import load_file, save_file
from PIL import Image
import torch.nn.functional as F

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

from .src.sampling import (
    batched_prc_img,
    batched_prc_txt,
    encode_image_refs,
    scatter_ids,
)

# Keyword arguments handed to CustomFlowMatchEulerDiscreteScheduler when the
# training / sampling scheduler is built.
scheduler_config = dict(
    base_image_seq_len=256,
    base_shift=0.5,
    max_image_seq_len=4096,
    max_shift=1.15,
    num_train_timesteps=1000,
    shift=3.0,
    use_dynamic_shifting=True,
)

# Hub repo id used for both the Mistral text encoder and its processor.
MISTRAL_PATH = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
# File names looked up inside a local model folder or a hub repo.
FLUX2_VAE_FILENAME = "ae.safetensors"
FLUX2_TRANSFORMER_FILENAME = "flux2-dev.safetensors"

# Optional Hugging Face access token; None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")


class Flux2Model(BaseModel):
    arch = "flux2"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        """Forward all arguments to ``BaseModel`` and set the Flux2 flags.

        Args:
            device: Passed through to ``BaseModel.__init__``.
            model_config: Model configuration, passed through to the base.
            dtype: Dtype identifier passed through to the base.
            custom_pipeline: Passed through to the base.
            noise_scheduler: Passed through to the base.
            **kwargs: Any extra keyword arguments accepted by the base.
        """
        base_args = (device, model_config, dtype, custom_pipeline, noise_scheduler)
        super().__init__(*base_args, **kwargs)

        # architecture flags
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["Flux2"]

        # when true, control images arrive as a list (used when encoding)
        self.has_multiple_control_images = True
        # control images are used at their raw size, not resized
        self.use_raw_control_images = True

    @staticmethod
    def get_train_scheduler():
        """Build a fresh flow-match scheduler from the module ``scheduler_config``."""
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def get_bucket_divisibility(self):
        return 16

    def load_model(self):
        """Load the Flux2 transformer, Mistral text encoder and VAE, then build the pipeline.

        Steps, in order:
          1. Resolve and load the transformer weights (local file or hub
             download), cast them to ``self.torch_dtype`` and optionally
             quantize the transformer.
          2. Load the Mistral text encoder and its processor from
             ``MISTRAL_PATH``, optionally quantizing the encoder.
          3. Resolve and load the VAE weights.
          4. Create the scheduler and a ``Flux2Pipeline`` and store the
             components on ``self`` (``vae``, ``text_encoder``, ``tokenizer``,
             ``model``, ``pipeline``, ``noise_scheduler``).

        The order of device moves, quantization and ``flush()`` calls is
        deliberate; do not reorder without checking memory behaviour.
        """
        dtype = self.torch_dtype
        self.print_and_status_update("Loading Flux2 model")
        # will be updated if we detect an existing checkpoint in training folder
        model_path = self.model_config.name_or_path
        transformer_path = model_path

        self.print_and_status_update("Loading transformer")
        # build the module on the meta device; real weights are attached below
        # through load_state_dict(..., assign=True)
        with torch.device("meta"):
            transformer = Flux2(Flux2Params())

        # use local path if provided
        if os.path.exists(os.path.join(transformer_path, FLUX2_TRANSFORMER_FILENAME)):
            transformer_path = os.path.join(
                transformer_path, FLUX2_TRANSFORMER_FILENAME
            )

        # at this point transformer_path is either an existing local path
        # (the folder's weights file, or name_or_path itself) or a hub repo id
        if not os.path.exists(transformer_path):
            # assume it is from the hub
            transformer_path = huggingface_hub.hf_hub_download(
                repo_id=model_path,
                filename=FLUX2_TRANSFORMER_FILENAME,
                token=HF_TOKEN,
            )

        transformer_state_dict = load_file(transformer_path, device="cpu")

        # cast to dtype
        for key in transformer_state_dict:
            transformer_state_dict[key] = transformer_state_dict[key].to(dtype)

        transformer.load_state_dict(transformer_state_dict, assign=True)

        # move to the device used for quantization before (optionally) quantizing
        transformer.to(self.quantize_device, dtype=dtype)

        if self.model_config.quantize:
            # patch the state dict method
            patch_dequantization_on_save(transformer)
            self.print_and_status_update("Quantizing Transformer")
            quantize_model(self, transformer)
            flush()
        else:
            # no quantization: go straight to the training device
            transformer.to(self.device_torch, dtype=dtype)
        flush()

        # optional layer offloading for the transformer
        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_transformer_percent > 0
        ):
            MemoryManager.attach(
                transformer,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_transformer_percent,
            )

        # low_vram: park the transformer on CPU while the text encoder is
        # loaded; it is moved back to device_torch near the end of this method
        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer to CPU")
            transformer.to("cpu")

        self.print_and_status_update("Loading Mistral")

        text_encoder: Mistral3ForConditionalGeneration = (
            Mistral3ForConditionalGeneration.from_pretrained(
                MISTRAL_PATH,
                torch_dtype=dtype,
            )
        )
        text_encoder.to(self.device_torch, dtype=dtype)

        flush()

        # optional weight quantization of the text encoder, frozen afterwards
        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Mistral")
            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
            freeze(text_encoder)
            flush()

        # optional layer offloading for the text encoder
        if (
            self.model_config.layer_offloading
            and self.model_config.layer_offloading_text_encoder_percent > 0
        ):
            MemoryManager.attach(
                text_encoder,
                self.device_torch,
                offload_percent=self.model_config.layer_offloading_text_encoder_percent,
            )

        # the "tokenizer" is the AutoProcessor that ships with the Mistral repo
        tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)

        self.print_and_status_update("Loading VAE")
        vae_path = self.model_config.vae_path

        # a VAE file inside the model folder takes precedence over vae_path
        if os.path.exists(os.path.join(model_path, FLUX2_VAE_FILENAME)):
            vae_path = os.path.join(model_path, FLUX2_VAE_FILENAME)

        if vae_path is None or not os.path.exists(vae_path):
            # assume it is from the hub
            vae_path = huggingface_hub.hf_hub_download(
                repo_id=model_path,
                filename=FLUX2_VAE_FILENAME,
                token=HF_TOKEN,
            )
        # same meta-device construction as the transformer
        with torch.device("meta"):
            vae = AutoEncoder(AutoEncoderParams())

        vae_state_dict = load_file(vae_path, device="cpu")

        # cast to dtype
        for key in vae_state_dict:
            vae_state_dict[key] = vae_state_dict[key].to(dtype)

        # NOTE: the VAE is not moved to a device in this method; its weights
        # stay where load_file put them (loaded with device="cpu")
        vae.load_state_dict(vae_state_dict, assign=True)

        self.noise_scheduler = Flux2Model.get_train_scheduler()

        self.print_and_status_update("Making pipe")

        pipe: Flux2Pipeline = Flux2Pipeline(
            scheduler=self.noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
        )
        # for quantization, it works best to do these after making the pipe
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        # wrapped in single-element lists; they are stored that way on self below
        text_encoder = [pipe.text_encoder]
        tokenizer = [pipe.tokenizer]

        flush()
        # just to make sure everything is on the right device and dtype
        text_encoder[0].to(self.device_torch)
        text_encoder[0].requires_grad_(False)
        text_encoder[0].eval()
        # also brings the transformer back from CPU when low_vram parked it there
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder  # list of text encoders
        self.tokenizer = tokenizer  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe
        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        """Assemble a ``Flux2Pipeline`` for sampling from the loaded components.

        A new scheduler is created for the pipeline; the text encoder, VAE and
        transformer are passed through ``unwrap_model`` first. The pipeline is
        moved to ``self.device_torch`` before being returned.
        """
        components = dict(
            scheduler=Flux2Model.get_train_scheduler(),
            text_encoder=unwrap_model(self.text_encoder[0]),
            tokenizer=self.tokenizer[0],
            vae=unwrap_model(self.vae),
            transformer=unwrap_model(self.transformer),
        )
        pipeline: Flux2Pipeline = Flux2Pipeline(**components)
        return pipeline.to(self.device_torch)

    def generate_single_image(
        self,
        pipeline: Flux2Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        gen_config.width = (
            gen_config.width // self.get_bucket_divisibility()
        ) * self.get_bucket_divisibility()
        gen_config.height = (
            gen_config.height // self.get_bucket_divisibility()
        ) * self.get_bucket_divisibility()

        control_img_list = []
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)
        elif gen_config.ctrl_img_1 is not None:
            control_img = Image.open(gen_config.ctrl_img_1)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)
        if gen_config.ctrl_img_2 is not None:
            control_img = Image.open(gen_config.ctrl_img_2)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)
        if gen_config.ctrl_img_3 is not None:
            control_img = Image.open(gen_config.ctrl_img_3)
            control_img = control_img.convert("RGB")
            control_img_list.append(control_img)

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            control_img_list=control_img_list,
            **extra,
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        guidance_embedding_scale: float,
        batch: "DataLoaderBatchDTO" = None,
        **kwargs,
    ):
        """Run the transformer on packed latents and return the noise prediction.

        Input preparation (packing, control image encoding) runs under
        ``torch.no_grad()``; the transformer forward pass itself does not.

        Args:
            latent_model_input: 4-D latent batch; its shape is unpacked as
                (batch, channels, height, width).
            timestep: Timesteps on a 0-1000 scale; divided by 1000 before
                being given to the transformer.
            text_embeddings: Prompt embeddings; ``text_embeds`` is used.
            guidance_embedding_scale: Value used to fill the guidance vector.
            batch: Optional batch object. When it is provided and its
                ``control_tensor_list`` is not None, that list must hold one
                list of control image tensors per batch item.
            **kwargs: Ignored.

        Returns:
            The unpacked noise prediction for the target latents only (any
            control-image tokens are sliced off).

        Raises:
            ValueError: If the control tensor list length differs from the
                batch size.
        """
        with torch.no_grad():
            txt, txt_ids = batched_prc_txt(text_embeddings.text_embeds)
            packed_latents, img_ids = batched_prc_img(latent_model_input)

            # prepare image conditioning if any
            img_cond_seq: torch.Tensor | None = None
            img_cond_seq_ids: torch.Tensor | None = None

            # batch defaults to None, so guard before touching its attributes
            control_tensor_lists = (
                batch.control_tensor_list if batch is not None else None
            )

            # handle control images
            if control_tensor_lists is not None:
                batch_size, _, height, width = latent_model_input.shape

                match_target_res = self.model_config.model_kwargs.get(
                    "match_target_res", False
                )

                control_image_max_res = 1024 * 1024
                if match_target_res:
                    # use the current target size to set the control image res
                    vae_scale = self.pipeline.vae_scale_factor
                    control_image_max_res = (
                        height * vae_scale * width * vae_scale
                    )

                if len(control_tensor_lists) != batch_size:
                    raise ValueError(
                        "Control tensor list length does not match batch size"
                    )

                cond_seq_items = []
                cond_seq_id_items = []
                for item_controls in control_tensor_lists:
                    # item_controls is the list of control tensors for one batch item
                    controls = []
                    for control_img in item_controls:
                        # control images are expected on a 0 - 1 scale
                        control_img = control_img.to(
                            self.device_torch, dtype=self.torch_dtype
                        )
                        # if it is only 3 dim, add batch dim
                        if control_img.dim() == 3:
                            control_img = control_img.unsqueeze(0)

                        if match_target_res:
                            # Resize to roughly the target pixel area while
                            # keeping the aspect ratio. Tensors are
                            # (..., height, width), so width / height is
                            # shape[3] / shape[2]. Using height / width here
                            # would transpose non-square control images.
                            aspect = control_img.shape[3] / control_img.shape[2]
                            c_width = math.sqrt(control_image_max_res * aspect)
                            c_height = c_width / aspect

                            # snap to multiples of 32, never below one step
                            c_width = max(32, round(c_width / 32) * 32)
                            c_height = max(32, round(c_height / 32) * 32)

                            # F.interpolate expects size as (height, width)
                            control_img = F.interpolate(
                                control_img, size=(c_height, c_width), mode="bilinear"
                            )

                        # scale to -1 to 1
                        controls.append(control_img * 2 - 1)

                    seq_item, seq_ids_item = encode_image_refs(
                        self.vae, controls, limit_pixels=control_image_max_res
                    )
                    cond_seq_items.append(seq_item)
                    cond_seq_id_items.append(seq_ids_item)

                if cond_seq_items:
                    # stack the per-item conditioning along the batch dimension
                    img_cond_seq = torch.cat(cond_seq_items, dim=0)
                    img_cond_seq_ids = torch.cat(cond_seq_id_items, dim=0)

            img_input = packed_latents
            img_input_ids = img_ids

            if img_cond_seq is not None:
                assert img_cond_seq_ids is not None, (
                    "You need to provide either both or neither of the sequence conditioning"
                )
                # append the conditioning tokens after the target tokens
                img_input = torch.cat((img_input, img_cond_seq), dim=1)
                img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)

            guidance_vec = torch.full(
                (img_input.shape[0],),
                guidance_embedding_scale,
                device=img_input.device,
                dtype=img_input.dtype,
            )

            cast_dtype = self.model.dtype

        packed_noise_pred = self.transformer(
            x=img_input.to(self.device_torch, cast_dtype),
            x_ids=img_input_ids.to(self.device_torch),
            timesteps=timestep.to(self.device_torch, cast_dtype) / 1000,
            ctx=txt.to(self.device_torch, cast_dtype),
            ctx_ids=txt_ids.to(self.device_torch),
            guidance=guidance_vec.to(self.device_torch, cast_dtype),
        )

        if img_cond_seq is not None:
            # keep only the tokens that belong to the target latents
            packed_noise_pred = packed_noise_pred[:, : packed_latents.shape[1]]

        if isinstance(packed_noise_pred, QTensor):
            packed_noise_pred = packed_noise_pred.dequantize()

        noise_pred = torch.cat(scatter_ids(packed_noise_pred, img_ids)).squeeze(2)

        return noise_pred

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode ``prompt`` with the pipeline and wrap the result in ``PromptEmbeds``.

        The text encoder is moved to ``self.device_torch`` first if it is not
        already there. The mask returned by ``encode_prompt`` is discarded.
        """
        target_device = self.device_torch
        if self.pipeline.text_encoder.device != target_device:
            self.pipeline.text_encoder.to(target_device)

        encoded = self.pipeline.encode_prompt(prompt, device=self.device_torch)
        embeds, _mask = encoded
        return PromptEmbeds(embeds)

    def get_model_has_grad(self):
        return False

    def get_te_has_grad(self):
        return False

    def save_model(self, output_path, meta, save_dtype):
        """Write the transformer weights to a ``.safetensors`` file.

        Args:
            output_path: Destination path; ``.safetensors`` is appended when
                the path does not already end with it.
            meta: Metadata passed through ``get_meta_for_safetensors``.
            save_dtype: Dtype every tensor is converted to before saving.
        """
        if not output_path.endswith(".safetensors"):
            output_path = output_path + ".safetensors"

        def _export(tensor):
            # quantized tensors are dequantized before being copied to CPU
            if isinstance(tensor, QTensor):
                tensor = tensor.dequantize()
            return tensor.clone().to("cpu", dtype=save_dtype)

        # only the transformer is saved
        transformer: Flux2 = unwrap_model(self.model)
        weights = transformer.state_dict()
        save_dict = {name: _export(tensor) for name, tensor in weights.items()}

        meta = get_meta_for_safetensors(meta, name="flux2")
        save_file(save_dict, output_path, metadata=meta)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        return (noise - batch.latents).detach()

    def get_base_model_version(self):
        return "flux2"

    def get_transformer_block_names(self) -> Optional[List[str]]:
        return ["double_blocks", "single_blocks"]

    def convert_lora_weights_before_save(self, state_dict):
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def encode_images(self, image_list: List[torch.Tensor], device=None, dtype=None):
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        # Move to vae to device if on cpu
        if self.vae.device == torch.device("cpu"):
            self.vae.to(device)
        # move to device and dtype
        image_list = [image.to(device, dtype=dtype) for image in image_list]
        images = torch.stack(image_list).to(device, dtype=dtype)

        latents = self.vae.encode(images)

        return latents
